#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import copy
import time
import pickle
import numpy as np
import random
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNCifar
from utils import get_dataset, average_weights, exp_details, compute_grad_update,init_deterministic,aggregate_signsgd


if __name__ == '__main__':
    # Federated training driver: builds a global model, then for each global
    # round samples a random subset of users, trains local copies on their
    # data, aggregates the results (FedAvg weight averaging, or sign-SGD style
    # gradient aggregation), and records train/test accuracy. The accuracy
    # curves are pickled under ./result/ at the end of the run.
    init_deterministic()
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('logs')

    args = args_parser()
    exp_details(args)

    # NOTE(review): a GPU index of 0 is falsy and therefore selects the CPU;
    # confirm how options.args_parser encodes "no GPU".
    if args.gpu:
        torch.cuda.set_device(args.gpu)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups; user_groups[i] is passed to LocalUpdate
    # as the `idxs` of user i.
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        else:
            # Without this branch global_model stays undefined and the script
            # fails much later with an unhelpful NameError.
            raise SystemExit('Error: unrecognized dataset for cnn model')
    elif args.model == 'mlp':
        # Multi-layer perceptron: its input size is the flattened shape of
        # one training sample.
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        print('len_in', len_in)
        global_model = MLP(dim_in=len_in, dim_hidden=args.dim_hidden,
                           dim_out=args.num_classes)
    else:
        raise SystemExit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_accuracy = []
    print_every = 2
    test_accs = []

    # Learning rate used for the current round; after each round it is
    # replaced by the value returned from LocalUpdate.update_weights.
    lr_epoch = args.lr

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_grads = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()

        # Each user joins this round independently with probability
        # args.frac, so the number of participants varies and can be zero.
        m = np.random.binomial(1, args.frac, args.num_users)
        idxs_users = np.where(m != 0)
        selected_users = idxs_users[0]

        if args.adaptive == 0 and args.Byzantine > 0:
            # Byzantine users are drawn from the whole population, so some of
            # them may not be participating in this round.
            Byzantine_user = np.array(random.sample(
                range(args.num_users),
                int(np.floor(args.Byzantine * args.num_users))))
        elif args.adaptive == 1 and args.Byzantine > 0:
            # Byzantine users are the first fraction of this round's
            # participants.
            Byzantine_user = selected_users[
                :int(np.floor(args.Byzantine * len(selected_users)))]
        else:
            Byzantine_user = []

        print('lr_epoch:', lr_epoch)
        # Learning rate carried into the next round. Default to the current
        # one so that a round in which no user was sampled keeps it instead
        # of raising NameError on an unbound `lr`.
        lr = lr_epoch

        if len(selected_users) == 0:
            # Nothing to aggregate: leave the global model unchanged.
            print('No users sampled this round; skipping the global update.')
        elif args.mode == 'FedAvg':
            for idx in selected_users:
                local_model = LocalUpdate(args=args, dataset=train_dataset,
                                          idxs=user_groups[idx], logger=logger)
                w, lr = local_model.update_weights(
                    model=copy.deepcopy(global_model), global_round=epoch,
                    lr_epoch=lr_epoch)
                local_weights.append(copy.deepcopy(w.state_dict()))

            # NOTE(review): local_weights is ordered by position among this
            # round's participants, while Byzantine_user holds user ids;
            # confirm utils.average_weights expects that pairing.
            global_weights = average_weights(local_weights, Byzantine_user)
            # update global weights
            global_model.load_state_dict(global_weights)
        else:
            for idx in selected_users:
                local_model = LocalUpdate(args=args, dataset=train_dataset,
                                          idxs=user_groups[idx], logger=logger)
                w, lr = local_model.update_weights(
                    model=copy.deepcopy(global_model), global_round=epoch,
                    lr_epoch=lr_epoch)
                local_grad = compute_grad_update(args, global_model, w,
                                                 lr=lr_epoch, device=device)
                # Byzantine_user is empty when args.Byzantine <= 0, so this
                # only triggers for attackers, who send the negated update.
                if idx in Byzantine_user:
                    local_grad = [torch.div(local_grad_i, -1)
                                  for local_grad_i in local_grad]
                local_grads.append(local_grad)

            # update global weights
            if args.mode == 'SIGNSGD' or args.mode == 'STO-SIGNSGD':
                global_model = aggregate_signsgd(
                    epoch, args, global_model, local_grads, lr=lr_epoch,
                    device=device, user_groups=user_groups)
        lr_epoch = lr

        # Calculate avg training accuracy over all users at every epoch
        list_acc = []
        global_model.eval()
        for c in range(args.num_users):
            # Score user c on its own data. The stale training-loop variable
            # `idx` must not be used here: it would score one user repeatedly.
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[c], logger=logger)
            acc = local_model.inference(model=global_model)
            list_acc.append(acc)
        train_accuracy.append(sum(list_acc) / len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch + 1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print('Train Accuracy: {:.2f}% \n'.format(100 * train_accuracy[-1]))

        # Test inference after every global round
        test_acc = test_inference(args, global_model, test_dataset)
        test_accs.append(test_acc)
        print(f' \n Results after {epoch+1} global rounds of training:')
        print("|---- Avg Train Accuracy: {:.2f}%".format(100 * train_accuracy[-1]))
        print("|---- Test Accuracy: {:.2f}%".format(100 * test_acc))
        print('test_accs', test_accs)

    print('test_accs', test_accs)

    file_name = './result/{}_{}_E[{}]_C[{}]_iid[{}]_alpha[{}]_LE[{}]_B[{}]_lr[{}]_b[{}]_beta[{}]_frac[{}]_hidden[{}]_momentum[{}]_{}_Byzantine[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid, args.alpha,
               args.local_ep, args.local_bs, args.lr, args.b, args.beta, args.frac,
               args.dim_hidden, args.momentum, args.mode, args.Byzantine)

    # Create the output directory on first use so the dump below cannot fail
    # with FileNotFoundError after a long training run.
    os.makedirs(os.path.dirname(file_name), exist_ok=True)
    with open(file_name, 'wb') as f:
        pickle.dump([train_accuracy, test_accs], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time() - start_time))
